Core Data Structures & Communication Primitives for Tensor Parallel for Keras #21697
Conversation
Summary of Changes
Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request lays the fundamental groundwork for introducing backend-agnostic auto-sharding and Tensor Parallelism into Keras 3.0. It establishes a modular and extensible architecture by defining core data structures, abstracting distributed backend functionality, and providing high-level communication primitives. This initial set of changes is crucial for enabling future capabilities that will allow users to train very large models across multiple devices with significantly simplified code.
Code Review
This pull request lays a solid foundation for tensor parallelism in Keras by introducing backend-agnostic abstractions for distributed operations and core data structures for sharding. The overall design is well-structured, separating concerns between backend-specific implementations, communication primitives, and configuration. However, there are several areas that need attention, particularly regarding the correctness of some backend implementations (especially JAX), placeholder logic, API clarity, and code consistency. Addressing these points will strengthen the foundation and prevent issues in future development.
Codecov Report
❌ Patch coverage is …
@@ Coverage Diff @@
## master #21697 +/- ##
==========================================
+ Coverage 82.59% 82.62% +0.02%
==========================================
Files 572 573 +1
Lines 58327 58608 +281
Branches 9131 9161 +30
==========================================
+ Hits 48177 48425 +248
- Misses 7818 7854 +36
+ Partials 2332 2329 -3
I've added a few initial comments and questions during my first look.
To make the review more manageable, I propose we split this change up. At almost 1,800 lines, the current change is quite difficult to review properly. What do you think about limiting this PR to just the JAX backend, and introducing the others in subsequent, smaller PRs?
Thank you for the PR!
Some high-level comments:
- Out of context, it's really hard for me to understand why these abstractions are needed for Tensor Parallel.
- Why do we need all these primitives?
- Why do we need 3 layers of abstraction for the same concepts: the `communications` layer, the `state_actions` layer, and the `keras.distributed.get_communication_ops` layer? Can we just have one?
- These abstractions look Torch-like and not JAX-like. On JAX you never have to manually split and do an all-gather, you simply shard. You never have to explicitly do a "collective sum". You just do a sum, and if the tensors are sharded, it will magically do all the needed collectives for you (see the sketch after this list). So it's unclear to me why any of these are needed for JAX.
- I wouldn't export these symbols that you added to `keras.distributed`; I don't think they are needed. What we'll expose is the "Tensor Parallel" API.
- For better or worse, we don't do type annotations in Keras. And unfortunately, mixing code that has type annotations with code that doesn't does not work well. It's better to not have any type annotations at all.
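To make the JAX point concrete, here is a minimal sketch (not part of this PR; the shapes and mesh setup are illustrative, and it assumes the leading dimension divides evenly across the devices) showing how a plain `jnp.sum` over a sharded array triggers the required collectives without any explicit all-gather or collective-sum call:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D device mesh with a single "model" axis.
mesh = Mesh(jax.devices(), axis_names=("model",))

# Place an array so its first axis is sharded across the "model" axis.
x = jnp.ones((8, 1024))
x = jax.device_put(x, NamedSharding(mesh, PartitionSpec("model", None)))

@jax.jit
def total(t):
    # No explicit collective here: XLA inserts the cross-device
    # reduction needed to sum a sharded array.
    return jnp.sum(t)

print(total(x))  # same value as the single-device result
```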
This Pull Request introduces the foundational components for a new, backend-agnostic auto-sharding system in Keras, specifically designed for tensor parallelism. It establishes the core data structures and the JAX-specific implementation of communication primitives.
Core Backend-Agnostic Abstractions
The most significant part of this PR is the creation of a generic, backend-agnostic system for defining sharding plans. This logic resides in keras/src/distribution/tensor_parallel/tensor_layout.py.
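As a rough illustration of what a sharding-plan structure can look like, here is a hypothetical sketch (not the actual contents of `tensor_layout.py`; all names, fields, and the rule format are made up for illustration):

```python
# Hypothetical sketch only -- not the API defined in this PR.
# A "sharding plan" pairing weight-name patterns with how that weight
# should be split across the model-parallel devices.
from dataclasses import dataclass, field


@dataclass
class ShardSpec:
    axis: int        # which tensor axis to split (e.g. 1 for column-parallel)
    world_size: int  # number of model-parallel shards


@dataclass
class ShardingPlan:
    # Maps a regex over weight names to how that weight is partitioned;
    # weights with no matching rule stay replicated.
    rules: dict = field(default_factory=dict)

    def add(self, pattern, spec):
        self.rules[pattern] = spec
        return self


# Example: split dense kernels column-wise across 4 devices.
plan = ShardingPlan().add(r".*dense.*kernel", ShardSpec(axis=1, world_size=4))
```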
JAX-Specific Backend Implementation
This PR provides the first backend-specific implementation of the required distributed communication primitives, targeting the JAX backend.
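For context, the kind of JAX collective such a backend layer would typically wrap looks roughly like the following. This is an illustrative sketch, not the code in this PR; the mesh axis name and shapes are assumptions.

```python
# Illustrative sketch only -- a collective-sum ("all-reduce") wrapper of
# the kind a JAX backend layer might expose, built on jax.lax.psum
# inside shard_map.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("model",))


def _psum(x):
    # Runs once per device on that device's shard; psum sums across the mesh.
    return jax.lax.psum(x, axis_name="model")


# in_specs shards the input along "model"; out_specs=P() marks the
# summed result as replicated on every device.
all_reduce_sum = shard_map(_psum, mesh=mesh, in_specs=P("model"), out_specs=P())

x = jnp.arange(float(4 * jax.device_count()))
print(all_reduce_sum(x))  # elementwise sum of the per-device shards
```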
Design Document: Autosharding for Keras
Example usage: https://colab.research.google.com/drive/1UAINIcstDuO0aeA9lxCF5LaIj5ne5X5z?resourcekey=0-pPF4COO19KRoqS5cpWNILA&usp=sharing
The full code of Tensor Parallel for Keras has been divided into 4 PRs; this is the first of them.